{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LLM-Blender usage examples\n",
    "\n",
    "Requirements for running the following Jupyter examples (the default requirements support inference only, with minimal dependencies):\n",
    "```bash\n",
    "pip install -e .[example]\n",
    "```\n",
    "\n",
    "## Loading blender (quick start)\n",
    "You can find more configuration options in:\n",
    "- PairRanker: [./llm_blender/pair_ranker/config.py](./llm_blender/pair_ranker/config.py)\n",
    "- GenFuser: [./llm_blender/gen_fuser/config.py](./llm_blender/gen_fuser/config.py)\n",
    "- Blender: [./llm_blender/blender/config.py](./llm_blender/blender/config.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "import llm_blender\n",
    "blender = llm_blender.Blender()\n",
    "# Load Ranker\n",
    "blender.loadranker(\"llm-blender/PairRM\") # load ranker checkpoint\n",
    "# blender.loadranker(\"OpenAssistant/reward-model-deberta-v3-large-v2\") # load ranker checkpoint\n",
    "# Load Fuser\n",
    "# blender.loadfuser(\"llm-blender/gen_fuser_3b\") # load fuser checkpoint if you want to use pre-trained fuser; or you can use ranker only"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the MixInstruct dataset used in the following examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import json\n",
    "from llm_blender.gpt_eval.cor_eval import COR_MAPS\n",
    "from llm_blender.gpt_eval.utils import get_ranks_from_chatgpt_cmps\n",
    "mixinstruct_test = datasets.load_dataset(\"llm-blender/mix-instruct\", split=\"test\", streaming=True)\n",
    "few_examples = list(mixinstruct_test.take(8))\n",
    "# parse cmp_results from JSON and drop examples that have no comparison results\n",
    "for ex in few_examples:\n",
    "    ex['cmp_results'] = json.loads(ex['cmp_results'])\n",
    "few_examples = [x for x in few_examples if x['cmp_results']]\n",
    "insts = [x['instruction'] for x in few_examples]\n",
    "inputs = [x['input'] for x in few_examples]\n",
    "candidates_texts = [[cand['text'] for cand in x['candidates']] for x in few_examples]\n",
    "print(\"Example:\")\n",
    "print(\"Instruction 1:\\n\", insts[0])\n",
    "print(\"Input 1:\\n\", inputs[0])\n",
    "print(\"Candidate 1 for input 1:\\n\")\n",
    "print(candidates_texts[0][0])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use case 1: Using LLM-Blender for ranking\n",
    "The `rank` function ranks the candidates for each input through pairwise comparisons and returns the ranks. We then show that our ranker's ranks are highly correlated with ChatGPT's ranks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranks = blender.rank(inputs, candidates_texts, instructions=insts, return_scores=False, batch_size=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Ranks for input 1:\", ranks[0]) # ranks of candidates for input 1\n",
    "# Ranks for input 1: [ 1 11  4  9 12  5  2  8  6  3 10  7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "llm_ranks_map, gpt_cmp_results = get_ranks_from_chatgpt_cmps(few_examples)\n",
    "gpt_ranks = np.array(list(llm_ranks_map.values())).T\n",
    "print(\"Correlation with ChatGPT\")\n",
    "print(\"------------------------\")\n",
    "for cor_name, cor_func in COR_MAPS.items():\n",
    "    print(cor_name, cor_func(ranks, gpt_ranks))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use case 2: Using LLM-Blender to directly compare two candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "candidates_A = [x['candidates'][0]['text'] for x in few_examples]\n",
    "candidates_B = [x['candidates'][1]['text'] for x in few_examples]\n",
    "comparison_results = blender.compare(\n",
    "    inputs, candidates_A, candidates_B, instructions=insts,\n",
    "    batch_size=2, return_logits=False)\n",
    "print(\"Comparison results for inputs:\", comparison_results) # one comparison result per input (candidate A vs. candidate B)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use case 3: Using LLM-Blender for fused generation\n",
    "We show that fusing the top-ranked candidates from the ranker can produce outputs of higher quality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks\n",
    "topk_candidates = get_topk_candidates_from_ranks(ranks, candidates_texts, top_k=3)\n",
    "fuse_generations = blender.fuse(inputs, topk_candidates, instructions=insts, batch_size=2)\n",
    "print(\"fuse_generations for input 1:\", fuse_generations[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Or rank and fuse in a single call\n",
    "fuse_generations, ranks = blender.rank_and_fuse(inputs, candidates_texts, instructions=insts, return_scores=False, batch_size=2, top_k=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llm_blender.common.evaluation import overall_eval\n",
    "metrics = ['bartscore']\n",
    "targets = [x['output'] for x in few_examples]\n",
    "scores = overall_eval(fuse_generations, targets, metrics)\n",
    "\n",
    "print(\"Fusion Scores\")\n",
    "for key, value in scores.items():\n",
    "    print(\"  \", key+\":\", np.mean(value))\n",
    "\n",
    "print(\"LLM Scores\")\n",
    "llms = [x['model'] for x in few_examples[0]['candidates']]\n",
    "llm_scores_map = {llm: {metric: [] for metric in metrics} for llm in llms}\n",
    "for ex in few_examples:\n",
    "    for cand in ex['candidates']:\n",
    "        for metric in metrics:\n",
    "            llm_scores_map[cand['model']][metric].append(cand['scores'][metric])\n",
    "for i, (llm, scores_map) in enumerate(llm_scores_map.items()):\n",
    "    print(f\"{i} {llm}\")\n",
    "    for metric, llm_scores in scores_map.items():\n",
    "        print(\"  \", metric+\":\", \"{:.4f}\".format(np.mean(llm_scores)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use case 4: Use LLM-Blender for decoding enhancement (best-of-n sampling)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceH4/zephyr-7b-beta\")\n",
    "model = AutoModelForCausalLM.from_pretrained(\"HuggingFaceH4/zephyr-7b-beta\", device_map=\"auto\")\n",
    "\n",
    "system_message = {\n",
    "    \"role\": \"system\",\n",
    "    \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n",
    "}\n",
    "messages = [\n",
    "    [   \n",
    "        system_message,\n",
    "        {\"role\": \"user\", \"content\": _inst + \"\\n\" + _input},\n",
    "    ]\n",
    "    for _inst, _input in zip(insts, inputs)\n",
    "]\n",
    "prompts = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]\n",
    "outputs = blender.best_of_n_generate(model, tokenizer, prompts, n=10)\n",
    "print(\"### Prompt:\")\n",
    "print(prompts[0])\n",
    "print(\"### best-of-n generations:\")\n",
    "print(outputs[0])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use case 5: Use PairRM for RLHF tuning\n",
    "\n",
    "To get scalar rewards, use the `blender.rank_with_ref` method (see the example below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards = blender.rank_with_ref(inputs, candidates_texts, return_scores=True, batch_size=2, mode=\"longest\")\n",
    "print(\"Rewards for input 1:\", rewards[0]) # rewards of candidates for input 1\n",
    "\"\"\"\n",
    "rewards is a List[List[float]] of shape (len(inputs), len(candidates_texts[0])),\n",
    "representing the reward of each candidate for each input.\n",
    "By default (mode=\"longest\"), the rewards are computed by comparing each candidate with the longest generation as the reference.\n",
    "Other supported modes are \"shortest\", \"median_length\", \"first\", and \"last\".\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also pass a list of references to compare with, instead of automatically selecting one from the candidates as the fixed reference.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ref_candidates = [_c[0] for _c in candidates_texts] # use the first candidate as the reference, same as mode=\"first\"\n",
    "rewards = blender.rank_with_ref(inputs, candidates_texts, return_scores=True, batch_size=2, ref_candidates=ref_candidates)\n",
    "\"\"\"\n",
    "ref_candidates = [ref1, ref2, ref3, ...] # ref_candidates is a List[str] of shape (len(inputs),)\n",
    "This parameter overrides the mode parameter; the given ref_candidates are used as the references for reward calculation.\n",
    "rewards is a List[List[float]] of shape (len(inputs), len(candidates_texts[0])).\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm_reranker",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
